import sys
from BasicLibraries import *
import pickle
import allocation_ga as af

#====
ONEDRIVE = ''
#====
AM = af.AllocationGAD()
#==== Command-line options (for the robustness tests)
#   argv[1]: binarisation quantile             (default: 0.75)
#   argv[2]: control for energy prices or not  (default: False; accepts
#            1/true/yes/y, case-insensitive, anything else means False)
# Each option is read independently. Previously a missing argv[2] made the
# script silently discard a user-supplied argv[1], and the energy flag read
# from argv[2] was then unconditionally overwritten with False.
BQL = sys.argv[1] if len(sys.argv) > 1 else str(0.75)
energy_control = (len(sys.argv) > 2
                  and sys.argv[2].strip().lower() in ('1', 'true', 'yes', 'y'))
print('Binarisation threshold:', BQL)
print('Control for energy prices for robustness tests:', energy_control)
BQL = float(BQL)

#==== Choose your population
# Per-realisation input files are named str(FW) with '.' -> '_', then '_',
# then this suffix (see the main loop).
population_type_ = 'LandscapePopulation_0724_full.txt'
#================
# Seed NumPy's global RNG so the bootstrap / mutation draws are reproducible.
np.random.seed(5)
#================

# Initiative dimensions passed to AM.fit_model as `dimension_`.
I = [        'association',
             'adoption of standards and rules',
             'assessment and measurement',
             'organizational structuring',
             'modification of procedures',
             'asset modification',
             'training',
             'r&d investments',
             'new products',
     ]
# SDG targets passed to AM.fit_model as `sdgs_number`.
env_sdgs = ['water & energy', 'cons. & prod.', 'biodiversity']


def mutate_genes(genes, rnd_i):
    """Return a copy of ``genes`` with the selected binary column(s) flipped.

    In the column(s) selected by ``rnd_i``, every 0 becomes 1 and every 1
    becomes 0 (the swap is simultaneous, not sequential). Any other values
    and all unselected columns are left untouched. ``genes`` itself is not
    modified.
    """
    swap = {0: 1, 1: 0}
    mutant = genes.copy()
    mutant[rnd_i] = mutant[rnd_i].replace(swap)
    return mutant

def gamma_epistasis(V, model, randomisation_vars, N = 2):
    """Compare a fixed mutation's fitness effect on ``V`` and on a neighbour of ``V``.

    ``N`` distinct loci are drawn without replacement from
    ``randomisation_vars`` using NumPy's global RNG. The first ``int(N/2)``
    of them define a neighbour genome (``V`` with those loci flipped by
    ``mutate_genes``). The remaining loci form the fixed mutation, which is
    applied both to ``V`` and to the neighbour.

    Parameters
    ----------
    V : pandas.DataFrame
        Genomes to mutate; the drawn labels must be columns of ``V``.
    model : object
        Anything with a ``predict`` method; it is called on ``np.array(...)``
        of each genome frame.
    randomisation_vars : array-like of column labels
        Candidate loci. np.random.choice raises if it has fewer than ``N``
        entries.
    N : int
        Total number of loci drawn.

    Returns
    -------
    tuple
        ``(sj, sjB)``: predicted fitness change caused by the fixed mutation
        on ``V`` and on the neighbour genome, respectively.
    """
    drawn = np.random.choice(randomisation_vars, N, replace = False)
    half = int(N / 2)
    neighbour_loci = drawn[:half]   # loci flipped to build the neighbour genome
    fixed_loci = drawn[half:]       # the mutation whose effect is compared

    def _mutation_effect(genome):
        # Fitness after the fixed mutation minus fitness before it.
        mutant = mutate_genes(genome, fixed_loci)
        before = model.predict(np.array(genome))
        after = model.predict(np.array(mutant))
        return after - before

    effect_here = _mutation_effect(V)
    neighbour = mutate_genes(V, neighbour_loci)
    effect_neighbour = _mutation_effect(neighbour)
    return effect_here, effect_neighbour



#%%
#==== Bootstrap estimate of the epistasis statistic Gamma
# For every weight realisation FW the landscape model is fitted once. Then, for
# each number of drawn loci N, `n_bootstrap` resamples of the genomes are taken
# and Gamma = 1 - pearson(sj, sjB) is recorded, where sj / sjB are the fitness
# changes from the same fixed mutation on the original / neighbour genomes.
n_bootstrap = 100      # bootstrap iterations per (FW, N) pair
epistasis = []
gamma_frames = []      # trimmed Gamma samples per (FW, N); concatenated once below
for FW in [0., 0.25, 0.5, 0.75, 1.0]:
    print('Realisation:', FW)
    population_type = str(FW).replace('.', '_') + '_' + population_type_
    dt = pd.read_csv(population_type, low_memory=False)
    gvkey_list = list(np.unique(dt.gvkey))

    # Only rows with at least one initiative are used for the fit.
    dat = dt[dt.number_of_initiatives > 0].copy()
    genesF, obs_fitnessF, forest_landscape, rSDG, _, _, _, _, _ = AM.fit_model(
        dat,
        sdgs_number=env_sdgs,
        dimension_=I,
        gvkey_list=gvkey_list,
        model_type='RandomForest',
        normal_fit=False,
        energy_price_control=energy_control,
        binarization_quantile=BQL,
        FW=FW,
    )
    # The first len(rSDG) gene columns are the loci that get mutated.
    vars_to_randomise = genesF.columns[:len(rSDG)]

    # Keep only genomes with at least two active mutable loci. This filter does
    # not depend on N, so it is computed once per realisation.
    btilde = genesF.reset_index()
    active_loci = btilde[vars_to_randomise].sum(axis=1)
    sparse_rows = active_loci[active_loci < 2].dropna().index
    btilde = btilde.drop(index=sparse_rows).set_index('gvkey')

    #===
    for N in [2, 4, 8, 16]:
        res = []
        for _ in range(n_bootstrap):
            # Bootstrap: resample 90% of the genomes with replacement.
            sample_idx = np.random.choice(int(len(btilde)), int(0.9 * len(btilde)), replace=True)
            b = btilde.iloc[sample_idx].reset_index(drop=True)
            sj, sjB = gamma_epistasis(b, forest_landscape, vars_to_randomise, N)
            res.append([1 - pearsonr(sj, sjB)[0], FW, N])

        #== Trim the 1% tails of the bootstrap distribution, then summarise.
        gamma_df = pd.DataFrame(res, columns=['Gamma', 'Weight', 'Mutations'])
        gamma_df = gamma_df[(gamma_df.Gamma < gamma_df.Gamma.quantile(0.99))
                            & (gamma_df.Gamma > gamma_df.Gamma.quantile(0.01))]
        gamma_df = gamma_df.dropna().reset_index(drop=True)
        gamma_mean = gamma_df['Gamma'].mean()
        gamma_median = gamma_df['Gamma'].median()
        ge_lb, ge_ub = np.quantile(gamma_df['Gamma'], 0.05), np.quantile(gamma_df['Gamma'], 0.95)

        epistasis.append([gamma_mean, gamma_median, ge_lb, ge_ub, 'Gamma', FW, N])
        print(gamma_mean, gamma_median, ge_lb, ge_ub, 'Gamma', FW, N)
        gamma_frames.append(gamma_df)

# A single concat avoids repeatedly concatenating onto an empty DataFrame,
# which pandas has deprecated because it can silently change column dtypes.
full_data = pd.concat(gamma_frames)
epistasis = pd.DataFrame(epistasis, columns = ['Statistics_mean' , 'Statistics_median', 'Lower bound', 'Upper bound', 'Type', 'Weight', 'Mutations'])
# Context manager guarantees the pickle file is flushed and closed.
with open('epistasis_' + str(BQL).replace('.', '_') + '.pckl', 'wb') as fh:
    pickle.dump([epistasis, full_data], fh)
sys.exit()
